from BasicLibraries import *
import regression as rg
from matplotlib import  rc
from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
import latex
# Global matplotlib styling for every figure produced by this module:
# LaTeX text rendering (with the cmbright package in the preamble), a
# sans-serif Helvetica font family, 24pt text and legend entries, and an
# 8pt gap between x major ticks and their labels.
plt.rcParams.update({
    'xtick.major.pad': '8',
    'text.usetex': True,
    'font.size': 24,
    'legend.fontsize': 24,
    'text.latex.preamble': r'\usepackage{cmbright}',
    'font.family': 'sans-serif',
    'font.sans-serif': ['Helvetica'],
})

# Four-colour palette; not referenced in the code visible here (presumably
# used by plots elsewhere in the file).
color = ['#e95057', '#0080c4', '#00a989', '#464646']

#%%
def add_exogeneous(H, path='local_data/env_policies_2.csv'):
    """Merge smoothed country-year environmental-policy indicators into a panel.

    The policy table is read from *path*. Its ``index`` column is renamed to
    ``mrg_cp`` and split on its first hyphen into a year (left part) and a
    country (right part). ``InstrTypeDiversity`` and ``DomainDiversity`` are
    divided by ``EnvPolicies``. All three columns are then replaced by a
    3-row rolling mean computed within each country, with rows ordered by
    year. Rows with any missing value are dropped. This includes the first
    two rows of every country, where the window is incomplete.

    Parameters
    ----------
    H : pandas.DataFrame
        Panel with an ``rfyear`` column (castable to int) and a ``loc``
        column. It is not modified.
    path : str, optional
        Location of the policy CSV. Defaults to the file used so far.

    Returns
    -------
    G : pandas.DataFrame
        Copy of *H* with a ``mrg_cp`` key (``"<rfyear>-<loc>"``) and the
        policy columns, inner-joined on that key. Rows of *H* without policy
        data are therefore dropped.
    exogeneous : list of str
        Names of the policy columns that were added (every policy column
        except the ``mrg_cp`` key).
    """
    smoothed = ['EnvPolicies', 'InstrTypeDiversity', 'DomainDiversity']

    policies = pd.read_csv(path).rename(columns={'index': 'mrg_cp'})
    # Keys look like "<year>-<country>". Split on the first hyphen only, so a
    # country code that itself contains '-' is kept whole.
    policies['country'] = policies['mrg_cp'].apply(lambda x: x.split('-', 1)[1])
    policies['rfyear'] = policies['mrg_cp'].apply(lambda x: x.split('-', 1)[0])
    # The rolling window below relies on chronological order within each
    # country. This is a string sort, so it assumes fixed-width (4-digit) years.
    policies = policies.sort_values(by=['country', 'rfyear'])

    # Normalise both diversity measures by EnvPolicies (presumably the number
    # of policies in force; confirm against the data source).
    policies['InstrTypeDiversity'] = policies['InstrTypeDiversity'] / policies['EnvPolicies']
    policies['DomainDiversity'] = policies['DomainDiversity'] / policies['EnvPolicies']

    # 3-row moving average within each country. ``transform`` returns a frame
    # aligned on the original index, so the assignment no longer depends on
    # the row order (or extra columns) of the groupby-rolling output.
    # NOTE: the window counts rows, not calendar years; gaps in a country's
    # years are not accounted for.
    policies[smoothed] = policies.groupby('country')[smoothed].transform(
        lambda s: s.rolling(3).mean())
    policies = policies.drop(columns=['rfyear', 'country']).dropna()

    G = H.copy()
    G['mrg_cp'] = G['rfyear'].astype(int).astype(str) + '-' + G['loc']
    G = G.merge(policies, on='mrg_cp')

    # Select by name rather than position, so the key column is excluded
    # wherever it sits in the CSV.
    exogeneous = [col for col in policies.columns if col != 'mrg_cp']
    return G, exogeneous

#%% Effect of diversification
def _brd_collect_results(dt_ws, dummies, YFE, driver_list):
    """Run ``rg.make_driver_regression`` once per driver and stack the outputs.

    Returns
    -------
    tab : pandas.DataFrame
        The formatted result row of each driver, indexed by driver name.
    res_bar : pandas.DataFrame
        One row per (outcome, driver) with columns ``Effect``, ``LB``, ``UB``,
        ``Uncertainty`` (half of ``UB - LB``) and ``Driver``.
    """
    tab_parts = []
    bar_parts = []
    for driver in driver_list:
        res, resNUM, _, _, _ = rg.make_driver_regression(
            dt_ws, dummies, driver, years_FE=YFE)
        res.index = [driver]
        # NOTE(review): assumes resNUM columns 0, 1, 2 hold the estimate and
        # the lower/upper bounds, as the relabelling below implies.
        resNUM[3] = (resNUM[2] - resNUM[1]) / 2
        resNUM.columns = ['Effect', 'LB', 'UB', 'Uncertainty']
        # One resNUM row per outcome column of the formatted result.
        resNUM.index = res.columns
        resNUM = resNUM.rename(index={'Risk mitigation': 'Risk Mit.',
                                      'Stakeholders engagement': 'Stak.Eng.'})
        resNUM['Driver'] = [driver] * len(resNUM)
        bar_parts.append(resNUM)
        tab_parts.append(res)
    # Concatenate once instead of growing a frame from an empty DataFrame.
    return pd.concat(tab_parts), pd.concat(bar_parts)


def _brd_effect_tables(res_bar):
    """Pivot stacked results into outcome-by-driver tables.

    Returns ``(effects, errors)``: the point estimates and the error-bar
    half-widths. Both are indexed by outcome, with one column per driver, in
    reverse order of first appearance in ``res_bar``.
    """
    drivers = res_bar['Driver'].unique()

    def _wide(column):
        wide = pd.concat(
            [res_bar.loc[res_bar['Driver'] == d, column].rename(d) for d in drivers],
            axis=1)
        return wide[wide.columns[::-1]]

    return _wide('Effect'), _wide('Uncertainty')


def _brd_plot_effects(effects, errors, labels):
    """Draw and save the two driver-effect bar charts under ``Figures/``."""
    # SI figure: every outcome, one bar per driver. A fresh figure is created
    # so nothing is drawn on top of whatever figure was current before.
    fig, ax = plt.subplots(figsize=(10, 12))
    effects.iloc[::-1].rename(columns=labels).plot.barh(
        xerr=errors.iloc[::-1].rename(columns=labels), capsize=4, width=0.85,
        ax=ax, colormap='tab20c', alpha=0.85)
    ax.axvline(x=0, lw=0.5, c='k', ls='-.')
    # Separators between consecutive outcome groups (bar groups sit at
    # y = 0, 1, ...); for 4 outcomes this yields y = 0.5, 1.5, 2.5.
    for row in range(len(effects) - 1):
        ax.axhline(y=row + 0.5, lw=0.5, c='k', ls='-.')
    ax.margins(y=0.0)
    # Legend entries are listed in reverse order of the plotted columns.
    handles, legend_labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], legend_labels[::-1], loc='center left',
              bbox_to_anchor=(1, 0.5))
    ax.set_xlabel('Association with response diversity')
    fig.savefig('Figures/Figure3_SI.pdf', dpi=100, bbox_inches='tight')

    # Main-text figure: only the first outcome row (presumably the
    # behavioural response diversity outcome; confirm against the data).
    fig, ax = plt.subplots(figsize=(8, 10))
    effects.iloc[0].rename(index=labels).plot.barh(
        xerr=errors.iloc[0].rename(index=labels), capsize=4, width=0.85,
        ax=ax, color='#077776', alpha=0.75)
    ax.axvline(x=0, lw=0.5, c='k', ls='-.')
    ax.set_xlabel('Association with response diversity')
    fig.savefig('Figures/Figure3.pdf', dpi=100, bbox_inches='tight')


def estimate_BRD_assets(dt_ws, dummies, YFE, sdgs_options_name):
    """Estimate each driver's association with the response-diversity outcomes.

    The panel is first augmented with the policy indicators returned by
    ``add_exogeneous``. ``rg.make_driver_regression`` is then run once per
    driver in a fixed list. The combined results table is printed as LaTeX,
    and two horizontal bar charts are saved: ``Figures/Figure3_SI.pdf`` (all
    outcomes) and ``Figures/Figure3.pdf`` (first outcome row only).

    Parameters
    ----------
    dt_ws : pandas.DataFrame
        Panel passed to ``add_exogeneous`` and then to every regression.
    dummies
        Forwarded unchanged to ``rg.make_driver_regression``.
    YFE : bool
        Forwarded as ``years_FE``. It also sets the "Year fixed effects"
        footer row.
    sdgs_options_name
        Unused; kept for interface compatibility.

    Returns
    -------
    res_bar : pandas.DataFrame
        Per-(outcome, driver) estimates with columns ``Effect``, ``LB``,
        ``UB``, ``Uncertainty`` and ``Driver``, with outcomes relabelled.
    tab : pandas.DataFrame
        One formatted row per driver followed by the specification footer.
    """
    driver_list = ['firm_size', 'inv_int', 'MTB', 'Tangibility', 'Leverage',
                   'Profitability', 'volatility', 'competition',
                   'DirectControl_lag', 'DomainDiversity', 'InstrTypeDiversity']
    adjust_names = {'IA': 'Information Asymmetry',
                    'firm_size': 'Size',
                    'inv_int': 'Inv.Intensity',
                    'DirectControl_lag': 'GHG emissions',
                    'volatility': 'Volatility',
                    'MTB': 'Market-to-Book', 'E_score': 'Environ. ratings',
                    'ESG': 'ESG ratings',
                    'competition': 'Competition',
                    'total_effort': 'Number of initiatives',
                    'envPolDIFF': 'Policy uncertainty',
                    'DomainDiversity': 'Policy domain diversity',
                    'InstrTypeDiversity': 'Policy Instrument diversity',
                    'EnvPolicies': 'Env. Policies'}
    outcome_names = {'Diversification_rw': 'Behavioural Response Diversity',
                     'Risk mitigation_rw': 'Risk mitigation',
                     'Stakeholders engagement_rw': 'Stakeholders engagement',
                     'Innovation_rw': 'Innovation'}

    dt_ws, _ = add_exogeneous(dt_ws)
    tab, res_bar = _brd_collect_results(dt_ws, dummies, YFE, driver_list)

    # Footer rows stating the specification shared by every column.
    tab.loc["Assets' characteristics"] = 'Yes'
    tab.loc['Sector fixed effects'] = 'Yes'
    tab.loc['Country fixed effects'] = 'Yes'
    tab.loc['Year fixed effects'] = 'Yes' if YFE else 'No'
    tab.loc['Self-selectivity'] = 'Yes'
    print(tab.to_latex())

    res_bar = res_bar.rename(index=outcome_names)
    effects, errors = _brd_effect_tables(res_bar)
    _brd_plot_effects(effects, errors, adjust_names)

    return res_bar, tab